# -*- coding: utf-8 -*-

### Generates intermediate data from raw output of censorship model prompts
### 
### Usage:
### 1. Install required packages, e.g., pip install pandas numpy fuzzywuzzy
### 2. Set input_dir_path to folder containing raw CSV files (e.g., './data/raw/')
### 3. Set output_dir_path to desired output folder (e.g., './data/intermediate/')
### 4. Run script: python var_coding.py

import os
import pandas as pd
import numpy as np
import re
from collections import Counter
from fuzzywuzzy import fuzz

# Define directories.
# Both placeholders must be set before running the script: input_dir_path is
# scanned for '.csv' files and output_dir_path receives the processed files.
input_dir_path = ''  # enter your input directory
output_dir_path = '' # enter your output directory

# Ensure the output directory exists (no error if it is already there)
os.makedirs(output_dir_path, exist_ok=True)

# Words that denote non-response. Chinese apology/refusal phrases, roughly:
# "sorry", "I am unable to", "I cannot", "I do not have", "I do not know",
# "sorry". process_row treats a response shorter than 100 characters that
# contains any of these as a non-response.
keywords = ["抱歉", "我无法", "我不能", "我没有", "我不知道", "对不起"]

# Utility to strip <think> and <思考> segments
# Closing markers that end a "thinking" segment; compiled once because the
# function is called for every response cell. The pattern contains no '.', so
# only IGNORECASE is needed (it lets '</THINK>' match as well).
_THINK_END_RE = re.compile(r'(?:</think>|/思考|/想象|思考"|思考>)', flags=re.IGNORECASE)


def strip_think_segments(text):
    """
    Return the part of a response that follows its last thinking marker.

    The text is scanned for closing markers (``</think>``, ``/思考``, ``/想象``,
    ``思考"``, ``思考>``). Everything up to and including the LAST marker is
    dropped and the remainder is whitespace-stripped. If no marker is
    present, the whole text is returned whitespace-stripped.

    Per the original note, this mirrors the processing in accuracy.py (not
    visible here), so the marker pattern itself is kept unchanged.

    Args:
        text: Response text. Missing values (None/NaN) yield "". Any other
            non-string value (e.g. a numeric CSV cell) is converted with
            str() instead of raising TypeError inside the regex engine.

    Returns:
        str: The cleaned response text.
    """
    if pd.isna(text):
        return ""
    if not isinstance(text, str):
        text = str(text)

    # NOTE(review): a well-formed '</思考>' is matched by the '/思考'
    # alternative, so its trailing '>' stays in the returned text. Left as-is
    # for parity with accuracy.py -- confirm whether that is intended.
    last_match = None
    for last_match in _THINK_END_RE.finditer(text):
        pass  # only the final match matters

    if last_match is not None:
        return text[last_match.end():].strip()
    return text.strip()

# Fuzzy matching threshold
# NOTE(review): this constant is not referenced anywhere in the visible code;
# the fuzzy comparison in process_row uses its own cutoff of 80. Confirm which
# value is intended before wiring this in.
threshold = 85

# Takes about 10s per file (145 rows)
def process_row(row, conv_cols):
    """
    Add per-row response statistics to one row of a raw results table.

    Every response in ``conv_cols`` is first cleaned with
    ``strip_think_segments``; all statistics are computed on the cleaned text.

    Args:
        row: pandas Series containing the ``conv_cols`` entries and a
            'prompt_txt' entry. It is modified in place and returned.
        conv_cols: List of labels of the response columns.

    Returns:
        The same ``row`` with these entries added:
          - '<col>_len' for each response column: cleaned response length.
          - 'modal_response' / 'modal_response_conv': the first response whose
            length is the most frequent one, and its column. The response is
            the literal "unique" when no two responses share a length.
          - 'longest_response' / 'longest_response_conv': the first response
            of maximal length, and its column.
          - 'per_nonresponse': percentage of responses shorter than 100
            characters that contain a non-response keyword.
          - 'conv_nonresponse': comma-separated columns of those responses.
          - 'per_exact_prompt': percentage of responses equal to the prompt.
          - 'per_fuzzy_prompt': percentage of responses whose fuzz.ratio
            against the prompt exceeds 80.

    Raises:
        IndexError: If ``conv_cols`` is empty (there is no modal length).
    """
    # Responses at or above this length are never counted as non-responses.
    nonresponse_max_len = 100
    # NOTE(review): the module also defines `threshold = 85`, which this
    # function has never used. The literal 80 is kept to preserve existing
    # results -- confirm which cutoff is intended.
    fuzzy_cutoff = 80

    # Strip thinking segments from all conversation responses before processing
    conv_strings_raw = row[conv_cols].fillna('').tolist()  # Convert NaNs to empty strings
    # Cells that pandas parsed as numbers/bools are turned into text first so
    # the regex-based stripping and len() below cannot fail on them.
    conv_strings = [
        strip_think_segments(s if isinstance(s, str) else str(s))
        for s in conv_strings_raw
    ]
    string_lengths = [len(s) for s in conv_strings]

    # 1. Calculate string lengths
    for col, length in zip(conv_cols, string_lengths):
        row[f'{col}_len'] = length

    # 2. Calculate modal string length and associated string
    length_counts = Counter(string_lengths)
    modal_length, modal_count = length_counts.most_common(1)[0]
    modal_index = string_lengths.index(modal_length)
    # A modal count of 1 means every response has a different length.
    modal_string = "unique" if modal_count == 1 else conv_strings[modal_index]
    modal_string_column = conv_cols[modal_index]

    # 3. Calculate longest string and associated string
    max_length = max(string_lengths)
    longest_index = string_lengths.index(max_length)
    longest_string = conv_strings[longest_index]
    longest_string_column = conv_cols[longest_index]

    # 4a. Identify short strings containing any non-response keywords
    conv_cols_with_keywords = [
        col for col, text in zip(conv_cols, conv_strings)
        if len(text) < nonresponse_max_len and any(kw in text for kw in keywords)
    ]
    # 4b. Calculate percentage of such strings
    count_string_less_100 = len(conv_cols_with_keywords)
    percent_string_less_100 = count_string_less_100 / len(conv_cols) * 100

    # 5./6. Compare each response with the prompt.
    # A missing prompt is read by pandas as NaN; fuzz.ratio cannot take a
    # float and raised TypeError for any non-empty response. With no prompt
    # there is nothing to match, so both counts are zero.
    prompt = row['prompt_txt']
    if pd.isna(prompt):
        count_exact_match = 0
        count_fuzzy_match = 0
    else:
        prompt = str(prompt)  # no-op for the usual string prompt
        # 5. Exact matches with prompt_txt
        count_exact_match = sum(s == prompt for s in conv_strings)
        # 6. Fuzzy matches with prompt_txt
        count_fuzzy_match = sum(
            fuzz.ratio(s, prompt) > fuzzy_cutoff for s in conv_strings
        )
    percent_exact_match = count_exact_match / len(conv_cols) * 100
    percent_fuzzy_match = count_fuzzy_match / len(conv_cols) * 100

    # Add results to the row
    row['modal_response'] = modal_string
    row['modal_response_conv'] = modal_string_column
    row['longest_response'] = longest_string
    row['longest_response_conv'] = longest_string_column
    row['per_nonresponse'] = percent_string_less_100
    row['conv_nonresponse'] = ', '.join(conv_cols_with_keywords)
    row['per_exact_prompt'] = percent_exact_match
    row['per_fuzzy_prompt'] = percent_fuzzy_match

    return row

def process_file(input_csv_path, output_csv_path):
    """
    Read one raw CSV, compute per-row statistics, and save the selected columns.

    Response columns are detected as every column whose name starts with
    'conv'. Each row is passed through ``process_row``; the output keeps
    'qn_num', 'prompt_txt', the per-response length columns and the summary
    columns that ``process_row`` adds.

    Args:
        input_csv_path: Path of the raw CSV file to read.
        output_csv_path: Path of the CSV file to write.

    Raises:
        ValueError: If the input has no column starting with 'conv'. This
            used to surface as an IndexError from inside DataFrame.apply.
        KeyError: If a non-empty input lacks 'qn_num' or 'prompt_txt'.
    """
    # Read the CSV file
    df = pd.read_csv(input_csv_path)

    # Identify conv columns dynamically
    conv_cols = [col for col in df.columns if col.startswith('conv')]
    if not conv_cols:
        raise ValueError(
            f"No response columns (names starting with 'conv') found in {input_csv_path}"
        )

    length_cols = [f'{col}_len' for col in conv_cols]
    columns_to_save = ['qn_num', 'prompt_txt'] + length_cols + [
        'modal_response', 'modal_response_conv', 'longest_response', 'longest_response_conv',
        'per_nonresponse', 'conv_nonresponse', 'per_exact_prompt', 'per_fuzzy_prompt'
    ]

    if df.empty:
        # DataFrame.apply on a frame with no rows returns it unchanged, so
        # the derived columns would be missing and the selection below would
        # raise KeyError. Emit a header-only file instead.
        result = pd.DataFrame(columns=columns_to_save)
    else:
        # Apply the process_row function to each row in the DataFrame
        processed = df.apply(process_row, axis=1, conv_cols=conv_cols)
        result = processed[columns_to_save]

    # Save the new DataFrame to a CSV file
    result.to_csv(output_csv_path, index=False)

    print(f"Processed and saved: {output_csv_path}")

# List all CSV files in the input directory.
# os.listdir returns names in arbitrary order; sorting makes the processing
# order and the "i/N" progress output reproducible between runs.
csv_files = sorted(f for f in os.listdir(input_dir_path) if f.endswith('.csv'))

# Raw inputs may be named '<name>_combined.csv'; outputs drop that marker.
combined_suffix = '_combined.csv'

# Process each CSV file
for i, csv_file in enumerate(csv_files, 1):
    input_csv_path = os.path.join(input_dir_path, csv_file)
    # Rewrite only a trailing '_combined.csv'. An unanchored str.replace
    # would also alter the marker if it appeared in the middle of a name.
    if csv_file.endswith(combined_suffix):
        output_csv_name = csv_file[:-len(combined_suffix)] + '.csv'
    else:
        output_csv_name = csv_file
    output_csv_path = os.path.join(output_dir_path, output_csv_name)
    print(f"Processing file {i}/{len(csv_files)}: {csv_file}")
    process_file(input_csv_path, output_csv_path)